import torch
import numpy as np
import click
from bgflow.utils import (
    distance_vectors,
    distances_from_vectors,
)
from bgflow import (
    DiffEqFlow,
    BoltzmannGenerator,
    MeanFreeNormalDistribution,
    BlackBoxDynamics,
)
from bgflow.utils import assert_numpy
from bgflow.bg import sampling_efficiency
import tqdm
from eq_ot_flow.estimator import BruteForceEstimatorFast
from eq_ot_flow.LJ import LennardJonesPotential
from LJ_utils import get_data, get_dynamics
from path_grad_helpers import load_weights
import json
from torchdyn.core import NeuralODE


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def compute_foward_ess(log_weight_hist):
    """Forward ESS."""
    log_weight_hist = log_weight_hist.flatten()
    print("Log weight hist shape")
    with torch.no_grad():
        # weight_hist = calc_imp_weights(log_weight_hist)
        max_log_uw, _ = log_weight_hist.max(-1)
        w = torch.exp(log_weight_hist - max_log_uw)
        inv_w = 1 / w

    ess = 1 / w.mean() / inv_w.mean()

    return ess.item()


def _sample_log_weights(bg, n_sample_batches, batch_size):
    """Sample ``n_sample_batches`` batches of ``batch_size`` from ``bg``.

    Returns the flattened unnormalized log importance weights produced by
    ``bg.log_weights_given_latent(..., normalize=False)``. The weights come
    back as a 1-D float64 NumPy array, which is empty when
    ``n_sample_batches`` is 0.

    Batches are collected in a list and concatenated once. This avoids
    re-copying a growing array on every iteration.
    """
    chunks = []
    for _ in tqdm.tqdm(range(n_sample_batches)):
        with torch.no_grad():
            samples, latent, dlogp = bg.sample(
                batch_size, with_latent=True, with_dlogp=True
            )
            log_weights = bg.log_weights_given_latent(
                samples, latent, dlogp, normalize=False
            )
        chunks.append(log_weights.detach().cpu().numpy().reshape(-1))
    if not chunks:
        return np.empty(0, dtype=np.float64)
    # Accumulate in float64 so the ESS and KL estimates use double precision.
    return np.concatenate(chunks).astype(np.float64)


def _holdout_energies(bg, target, data_holdout, batch_size):
    """Evaluate ``bg.energy`` and ``target.energy`` on ``data_holdout`` in batches.

    Returns:
        Tuple ``(bg_energies, target_energies)`` of NumPy arrays. Each is the
        concatenation of the per-batch results.

    Raises:
        ValueError: if ``data_holdout`` has no rows.
    """
    n_data = data_holdout.shape[0]
    if n_data == 0:
        raise ValueError("holdout set is empty; nothing to evaluate")
    bg_chunks, target_chunks = [], []
    with torch.no_grad():
        for start in tqdm.tqdm(range(0, n_data, batch_size)):
            # Move each batch to the device once and reuse it for both energies.
            batch = data_holdout[start : start + batch_size].to(device)
            bg_chunks.append(assert_numpy(bg.energy(batch)))
            target_chunks.append(assert_numpy(target.energy(batch)))
    return np.concatenate(bg_chunks), np.concatenate(target_chunks)


def _mean_path_length(net_dynamics, prior, n_samples=10000, n_steps=100):
    """Return the mean length of ODE trajectories started at prior samples.

    ``net_dynamics`` is integrated with torchdyn's ``NeuralODE`` (dopri5,
    atol = rtol = 1e-4) on ``n_steps`` evenly spaced times in [0, 1]. A
    trajectory's length is the sum of the Euclidean norms of its consecutive
    displacements. The mean over the ``n_samples`` trajectories is returned as
    a Python float.
    """
    node = NeuralODE(
        net_dynamics,
        solver="dopri5",
        sensitivity="adjoint",
        atol=1e-4,
        rtol=1e-4,
    )
    with torch.no_grad():
        x0 = prior.sample(n_samples)
        # Build the time grid on the same device as the initial states.
        t_span = torch.linspace(0, 1, n_steps, device=x0.device)
        traj = node.trajectory(x0, t_span=t_span)
        # Axis 0 of ``traj`` is treated as time and the last axis as coordinates.
        path_lengths = torch.linalg.norm(traj[1:] - traj[:-1], dim=-1).sum(0)
    return float(path_lengths.mean())


@click.command()
@click.option("--path", type=str, default="models")
@click.option("--n_particles", type=int, default=13)
@click.option("--data_path", default="/data")
@click.option("--n_holdout", default=100000)
@click.option("--n_sample_batches", default=20)
@click.option("--batch-size", default=10000)
@click.option("--n_knots_hutch", default=20)
def main(
    path, n_particles, n_sample_batches, data_path, n_holdout, batch_size, n_knots_hutch
):
    """Evaluate a trained Lennard-Jones Boltzmann generator.

    Loads the weights at PATH, then prints the following metrics and writes
    them to ``{path}-results.json``:

    - the ESS of model samples (``ess_q``) and the forward ESS on holdout
      data (``ess_p``);
    - the corresponding shifted KL estimates;
    - the mean model energy of the holdout data (``-log Q``);
    - the mean ODE path length of prior samples.
    """
    dim = n_particles * 3

    target = LennardJonesPotential(
        dim, n_particles, eps=1.0, rm=1, oscillator_scale=1, two_event_dims=False
    )

    data = get_data(data_path, n_particles)

    # Fixed seed so the same permutation (and hence holdout split) is used every run.
    np.random.seed(0)
    idx = np.random.choice(np.arange(len(data)), len(data), replace=False)
    # NOTE(review): the holdout set is everything *after* the first ``n_holdout``
    # permuted samples. ``n_holdout`` therefore acts as the number of skipped
    # (presumably training) samples; confirm this is intended.
    data_holdout = data[idx[n_holdout:]]
    print("Data Shape", data_holdout.shape)

    # Prior and flow that together form the Boltzmann Generator.
    prior = MeanFreeNormalDistribution(dim, n_particles, two_event_dims=False).to(
        device
    )
    net_dynamics = get_dynamics(n_particles)
    bb_dynamics = BlackBoxDynamics(
        dynamics_function=net_dynamics, divergence_estimator=BruteForceEstimatorFast()
    )
    # Setting to rk4 since that is also what they did for LJ55
    flow = DiffEqFlow(
        dynamics=bb_dynamics, integrator="rk4", n_time_steps=n_knots_hutch
    )
    bg = BoltzmannGenerator(prior, flow, target.to(device))
    load_weights(bg, path)

    # Importance weights of samples drawn from the model.
    log_w_np = _sample_log_weights(bg, n_sample_batches, batch_size)
    ess = sampling_efficiency(torch.from_numpy(log_w_np)).item()
    print(f"ESS-Q {ess}")

    # Forward direction: model energy minus target energy on holdout data.
    bg_energies_data, energies_data = _holdout_energies(
        bg, target, data_holdout, batch_size
    )
    log_w_tilde_p = bg_energies_data - energies_data
    ess_p = compute_foward_ess(torch.from_numpy(log_w_tilde_p))
    print(f"ESS-P {ess_p}")

    path_length = _mean_path_length(net_dynamics, prior)

    results = {
        "ess_q": float(ess),
        "KL(Q|P) - c": float(-np.mean(log_w_np)),
        "ess_p": float(ess_p),
        "KL(P|Q) + c": float(log_w_tilde_p.mean()),
        "-log Q": float(np.mean(bg_energies_data)),
        "Path Length": path_length,
    }
    print(results)
    # Context manager so the results file is flushed and closed.
    with open(f"{path}-results.json", "w") as results_file:
        json.dump(results, results_file)


if __name__ == "__main__":
    # ``main`` is a click command: calling it parses the command line and runs the evaluation.
    main()
